In [1]:
'''
NTK이론을 바탕으로 Fourier-featuring이 MLP로 하여금 high frequency부분이 잘 학습할 수 있게 한다


**Coordinate-based MLP
input : pixel 정보
output : 여러가지   
Fully connected layer

** MLP가 spectral bias 에 의해 high frequency부분을 잘 학습하지 못해
Positional Encoding 중의 하나인 Fourier-featuring을 사용하면 잘 학습할 수 있다.

왜 Fourier-featuring을 사용하면 잘 학습?  
A :이걸 NTK로 증명한 것이 이 논문의 contribution입니다 

'''
Out[1]:
'\nNTK이론을 바탕으로 Fourier-featuring이 MLP로 하여금 high frequency부분이 잘 학습할 수 있게 한다\n\n\n**Coordinate-based MLP\ninput : pixel 정보\noutput : 여러가지   \nFully connected layer\n\n** MLP가 spectral bias 에 의해 high frequency부분을 잘 학습하지 못해\nPositional Encoding 중의 하나인 Fourier-featuring을 사용하면 잘 학습할 수 있다.\n\n왜 Fourier-featuring을 사용하면 잘 학습?  \nA :이걸 NTK로 증명한 것이 이 논문의 contribution입니다 \n\n'
In [2]:
'''
***Kernel Method
Input space의 데이터를 선형분류가 가능한 고차원 공간으로 mapping한 뒤 두 범주를 분류하는 초평면을 찾는다

***kernel regression
linear regression과 달리 비선형 함수 같은 것을 regression 방법론
why? 자연상에 존재하는 비선형성을 찾는 것  ==이므로 kernel regrssion과 동치다

x는 새로운 데이터 xi가 원래 있는 데이터이고 원래 있는데이터를 학습해 새로들어온 데이터가 얼마나 유사한지 계산해서 weighted sum으로 분류하는 과정

***NTK
kernel regression을 이용해 Neural Network의 작동 원리를 설명하려는 방법론
목표 : MLP를 Kernel 함수꼴로 재정의하기

positon encoding이 왜 수렴에 효과적인지 증명하기 위해서 MLP를 kernel 함수로 재정의하는데 그 떄 사용되는게 NTK이다
'''
'''
MLP를 Kernel함수로 고치는 것이 왜 필요한가
1. MLP가 High frequency를 잘 학습하지 못하는 이유를 이해
-> 증명  : 수렴 속도는 그 성분의 eigenvalue에 의해 결정된다.
QT(yˆ(t) −y)≈QT  I−e−ηKty−y=−e−ηΛtQTy.



2. Positional Encoding같은 fourier-featuring을 첨가했을 때 학습이 잘 되는 이유를 이해
-> 증명 : Fourier-featuring을 첨가하면 데이터는 stationary(같은 패턴이 반복되는 성질)한 성질을 가짐 & MLP는 convolution된다
잊혀질수 있는데이터가 계속 반복되기 때문에 충분히 학습할 수 있는 여지를 준다
'''
Out[2]:
'\nMLP를 Kernel함수로 고치는 것이 왜 필요한가\n1. MLP가 High frequency를 잘 학습하지 못하는 이유를 이해\n-> 증명  : 수렴 속도는 그 성분의 eigenvalue에 의해 결정된다.\n\n2.\n'
In [48]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
import os
from PIL import Image
from tqdm import tqdm
import sys
import imageio
import math
print(os.getcwd())
device= torch.device('cuda' if torch.cuda.is_available else 'cpu')
print(device)
/ssd1/Fourier
cuda
In [49]:
def psnr(label, outputs, max_val=1.):
    label = label.cpu().detach().numpy()
    outputs = outputs.cpu().detach().numpy()
    img_diff = outputs - label
    rmse = math.sqrt(np.mean((img_diff)**2))
    if rmse == 0: # label과 output이 완전히 일치하는 경우
        return 100
    else:
        psnr = 20 * math.log10(max_val/rmse)
        return psnr
In [50]:
img=plt.imread('15677707699_d9d67acf9d_b.jpg')
img=img[...,:3] /255.
print(img.shape)
print(img.dtype)
c=[img.shape[0]//2,img.shape[1]//2]
r=256
img = img[c[0] - r:c[0] + r, c[1] - r:c[1] + r]
img.shape
(689, 1024, 3)
float64
Out[50]:
(512, 512, 3)
In [51]:
target=torch.tensor(img).unsqueeze(0).permute(0, 3, 1, 2).to(device)
print(target)
print(target.dtype)
#이미지 확인
sample=np.array(target[0].permute(1,2,0).cpu())
plt.imshow(sample)




coords=np.linspace(0,1,target.shape[2],endpoint=False)
print(coords.shape)
xy_grid=np.stack(np.meshgrid(coords,coords),-1)
print(xy_grid.shape)
xy_grid = torch.tensor(xy_grid).unsqueeze(0).permute(0, 3, 1, 2).float().contiguous().to(device)
print(xy_grid.shape)
tensor([[[[0.1922, 0.1373, 0.1176,  ..., 0.1765, 0.1882, 0.1961],
          [0.1608, 0.1569, 0.1451,  ..., 0.1765, 0.2000, 0.1961],
          [0.1725, 0.1804, 0.1725,  ..., 0.1961, 0.2039, 0.1804],
          ...,
          [0.2078, 0.1765, 0.1647,  ..., 0.7451, 0.6941, 0.6471],
          [0.1961, 0.1686, 0.1529,  ..., 0.8627, 0.8118, 0.7569],
          [0.1922, 0.1725, 0.1529,  ..., 0.9098, 0.8784, 0.8314]],

         [[0.2275, 0.1961, 0.2118,  ..., 0.2118, 0.2000, 0.1961],
          [0.1804, 0.2078, 0.2235,  ..., 0.2118, 0.2118, 0.2078],
          [0.1765, 0.2157, 0.2392,  ..., 0.2196, 0.2196, 0.2039],
          ...,
          [0.2157, 0.1961, 0.2039,  ..., 0.5882, 0.5725, 0.5529],
          [0.1922, 0.1843, 0.1922,  ..., 0.6902, 0.6745, 0.6588],
          [0.1882, 0.1882, 0.2039,  ..., 0.7333, 0.7373, 0.7373]],

         [[0.1216, 0.1059, 0.1098,  ..., 0.0902, 0.0863, 0.0863],
          [0.0627, 0.1059, 0.1137,  ..., 0.0980, 0.1059, 0.1020],
          [0.0549, 0.1020, 0.1137,  ..., 0.1176, 0.1216, 0.1020],
          ...,
          [0.1294, 0.1059, 0.0980,  ..., 0.4863, 0.4510, 0.4039],
          [0.1098, 0.0941, 0.0863,  ..., 0.5843, 0.5490, 0.5020],
          [0.1059, 0.0980, 0.0941,  ..., 0.6118, 0.5961, 0.5647]]]],
       device='cuda:0', dtype=torch.float64)
torch.float64
(512,)
(512, 512, 2)
torch.Size([1, 2, 512, 512])

NO mapping¶

MLP¶

In [52]:
class MLP(nn.Module) :
    def __init__(self) :
        super().__init__()

        self.linear=nn.Sequential(nn.Linear(524288,3),
                                  nn.ReLU(),
                                  nn.Linear(3,3),
                                  nn.ReLU(),
                                  nn.Linear(3,786432),
                                  nn.ReLU()
                                  )
    
    def forward(self,x) :
        x=torch.flatten(x,start_dim=1)
        x=self.linear(x)
        return x
        
In [53]:
model=MLP()
model=model.to(device)
optimizer=optim.Adam(list(model.parameters()),lr=1e-4)
generated=model(xy_grid)

generated=generated.reshape(1,3,512,512)
print(generated.dtype)
print(target.dtype)
print(target.shape)
print(nn.MSELoss()(generated,target))
torch.float32
torch.float64
torch.Size([1, 3, 512, 512])
tensor(0.2685, device='cuda:0', dtype=torch.float64,
       grad_fn=<MseLossBackward0>)
In [54]:
model=MLP()
model=model.to(device)
optimizer=optim.Adam(list(model.parameters()),lr=1e-4)
e1_loss,e1_psnr=[],[]
for epoch in tqdm(range(2000)) :
    optimizer.zero_grad()
    generated=model(xy_grid)
    generated=generated.reshape(1,3,512,512)
    #loss=nn.L1Loss()(target.detach(),generated)
    loss = nn.MSELoss()(generated,target.float())
    e1_loss.append(loss.item())
    loss.backward()
    e_psnr=psnr(generated,target)
    e1_psnr+=[e_psnr]
    optimizer.step()

    if epoch%500==0 :
        print('Epoch %d, loss = %.03f' % (epoch, float(loss)))
        plt.imshow(generated[0].permute(1,2,0).cpu().detach().numpy())
        plt.show()
  0%|          | 0/2000 [00:00<?, ?it/s]
Epoch 0, loss = 0.269
 25%|██▍       | 497/2000 [00:03<00:10, 147.98it/s]
Epoch 500, loss = 0.254
 49%|████▉     | 985/2000 [00:07<00:06, 149.80it/s]
Epoch 1000, loss = 0.241
 74%|███████▍  | 1488/2000 [00:10<00:03, 148.90it/s]
Epoch 1500, loss = 0.231
100%|██████████| 2000/2000 [00:14<00:00, 139.42it/s]
In [58]:
y=xy_grid.cpu()
y=y[0].numpy()
plt.figure(figsize=(16,16))
for idx in range(2) :
    t=y[idx]
    plt.subplot(1,2,idx+1,xticks=[],yticks=[])
    plt.imshow(t)

CNN-MLP¶

In [59]:
class CNN_MLP(nn.Module) :
    def __init__(self) :
        super().__init__()
        self.mlp=nn.Sequential(nn.Conv2d(2,256,kernel_size=1,padding=0),
                               nn.ReLU(),
                               nn.BatchNorm2d(256),

                               nn.Conv2d(256,256,kernel_size=1,padding=0),
                               nn.ReLU(),
                               nn.BatchNorm2d(256),

                               nn.Conv2d(256,256,kernel_size=1,padding=0),
                               nn.ReLU(),
                               nn.BatchNorm2d(256),

                               nn.Conv2d(256,3,kernel_size=1,padding=0),
                               nn.Sigmoid()
                               )
    
    def forward(self,x) :
        x=self.mlp(x)
        return x
        
In [60]:
model=CNN_MLP()
model.to(device)
optimizer=optim.Adam(list(model.parameters()),lr=1e-4)
e2_loss=[]
e2_psnr=[]
for epoch in tqdm(range(2000)):
    optimizer.zero_grad()

    generated = model(xy_grid)
    
    
    #loss = nn.L1Loss()(generated,target)
    loss = nn.MSELoss()(generated,target.float())
    e2_loss.append(loss.item())
    e_psnr=psnr(generated,target)
    e2_psnr+=[e_psnr]
    loss.backward()
    optimizer.step()

    if epoch % 500 == 0:
      print('Epoch %d, loss = %.03f' % (epoch, float(loss)))
      plt.imshow(generated[0].permute(1,2,0).cpu().detach().numpy())
      plt.show()
  0%|          | 0/2000 [00:00<?, ?it/s]
Epoch 0, loss = 0.084
 10%|█         | 200/2000 [00:07<01:02, 28.57it/s]
Epoch 200, loss = 0.014
 20%|█▉        | 398/2000 [00:14<00:56, 28.50it/s]
Epoch 400, loss = 0.010
 30%|███       | 600/2000 [00:21<00:49, 28.41it/s]
Epoch 600, loss = 0.009
 40%|███▉      | 798/2000 [00:28<00:42, 28.45it/s]
Epoch 800, loss = 0.009
 50%|█████     | 1000/2000 [00:35<00:35, 28.50it/s]
Epoch 1000, loss = 0.008
 60%|█████▉    | 1198/2000 [00:42<00:28, 28.46it/s]
Epoch 1200, loss = 0.008
 70%|███████   | 1400/2000 [00:50<00:21, 28.47it/s]
Epoch 1400, loss = 0.008
 80%|███████▉  | 1598/2000 [00:57<00:14, 28.46it/s]
Epoch 1600, loss = 0.008
 90%|█████████ | 1800/2000 [01:04<00:07, 28.37it/s]
Epoch 1800, loss = 0.008
100%|██████████| 2000/2000 [01:11<00:00, 27.87it/s]

mapping¶

In [61]:
def mapping(grid,input_channel,mapping_size,scale) :
    b=torch.randn((input_channel,mapping_size))*scale
    batches, channels, width,height=grid.shape
    x = grid.permute(0, 2, 3, 1).reshape(batches * width * height, channels)
    x = x @ b.to(x.device)
    x = x.view(batches, width, height, mapping_size)
    x = x.permute(0, 3, 1, 2)
    x = 2 * np.pi * x
    return torch.cat([torch.sin(x), torch.cos(x)], dim=1)


class Gaussian(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp=nn.Sequential(nn.Conv2d(256,256,1,padding=0),
                               nn.ReLU(),
                               nn.BatchNorm2d(256),
                               
                               nn.Conv2d(256,256,1,padding=0),
                               nn.ReLU(),
                               nn.BatchNorm2d(256),
                               
                               nn.Conv2d(256,256,1,padding=0),
                               nn.ReLU(),
                               nn.BatchNorm2d(256),
                               
                               nn.Conv2d(256,3,1,padding=0),
                               nn.Sigmoid()
                               )

    def forward(self, x):
        x=self.mlp(x)
        return x
b=torch.randn(2,5)
print(b)
b=torch.randn((2,128))*10

b.shape
tensor([[ 0.2820,  1.0227,  1.6649, -0.6847,  0.1256],
        [-0.3469, -1.0175, -1.7110, -0.4972,  1.5822]])
Out[61]:
torch.Size([2, 128])
In [62]:
x=mapping(xy_grid,2,128,10)

print(x.shape[1])
y=x[0]
plt.figure(figsize=(64,64))
for idx in range(256) :
    t=y[idx]
    t=t.cpu().numpy()
    plt.subplot(16,16,idx+1,xticks=[],yticks=[])
    plt.imshow(t)
256
In [63]:
x=mapping(xy_grid,2,128,10)
print(x.shape)
model=Gaussian().to(device)
optimizer=optim.Adam(list(model.parameters()),lr=1e-4)
e3_loss,e3_psnr=[],[]
for epoch in tqdm(range(2000)) :
    optimizer.zero_grad()
    pred=model(x)
    #loss=nn.L1Loss()(pred,target)
    loss = nn.MSELoss()(pred,target.float())
    e3_loss.append(loss.item())
    e_psnr=psnr(pred,target)
    e3_psnr+=[e_psnr]
    loss.backward()
    optimizer.step()


    if epoch%200==0 :
        print(f'Epoch {epoch}, loss = {float(loss)}')
        plt.imshow(pred[0].permute(1,2,0).cpu().detach().numpy())
        plt.show()
torch.Size([1, 256, 512, 512])
  0%|          | 0/2000 [00:00<?, ?it/s]
Epoch 0, loss = 0.08516941964626312
 10%|█         | 200/2000 [00:06<01:00, 29.78it/s]
Epoch 200, loss = 0.006859270390123129
 20%|█▉        | 398/2000 [00:13<00:53, 29.67it/s]
Epoch 400, loss = 0.004567769356071949
 30%|███       | 600/2000 [00:20<00:47, 29.73it/s]
Epoch 600, loss = 0.0035709913354367018
 40%|███▉      | 798/2000 [00:27<00:40, 29.54it/s]
Epoch 800, loss = 0.002985275350511074
 50%|████▉     | 999/2000 [00:34<00:33, 29.53it/s]
Epoch 1000, loss = 0.002562065375968814
 60%|██████    | 1200/2000 [00:41<00:27, 29.59it/s]
Epoch 1200, loss = 0.0022233647760003805
 70%|██████▉   | 1398/2000 [00:48<00:20, 29.67it/s]
Epoch 1400, loss = 0.0019457068992778659
 80%|████████  | 1600/2000 [00:55<00:13, 29.63it/s]
Epoch 1600, loss = 0.0017326179658994079
 90%|████████▉ | 1799/2000 [01:02<00:06, 29.67it/s]
Epoch 1800, loss = 0.0015709571307525039
100%|██████████| 2000/2000 [01:09<00:00, 28.95it/s]
In [64]:
plt.subplot(211)
#plt.plot(range(2000),e1_loss)
plt.plot(range(2000),e2_loss)
plt.plot(range(2000),e3_loss)
plt.xlabel('Epoch')
plt.ylabel('Error')

plt.subplot(212)
#plt.plot(range(2000),e1_psnr)
plt.plot(range(2000),e2_psnr)
plt.plot(range(2000),e3_psnr)
plt.xlabel('Epoch')
plt.ylabel('PSNR')
Out[64]:
Text(0, 0.5, 'PSNR')